#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import time
import torch
import logging
from contextlib import contextmanager
from typing import Dict, Optional, Tuple
from distutils.version import LooseVersion

from funasr.register import tables
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.scorers.length_bonus import LengthBonus
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch < 1.6.0: autocast becomes a no-op context manager
    @contextmanager
    def autocast(enabled=True):
        yield


@tables.register("model_classes", "BAT")  # TODO: BAT training
class BAT(torch.nn.Module):
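    """Boundary Aware Transducer (BAT) ASR model.

    Encoder + prediction network (decoder) + joint network, trained with the
    RNN-T loss (warp_rnnt), with optional auxiliary CTC and LM losses.
    """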
    def __init__(
        self,
        frontend: Optional[str] = None,
        frontend_conf: Optional[Dict] = None,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        joint_network: str = None,
        joint_network_conf: Optional[Dict] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        # consumed by _calc_transducer_loss when reporting CER/WER at eval time
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ):
        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            **decoder_conf,
        )
        decoder_output_size = decoder.output_size
        joint_network_class = tables.joint_network_classes.get(joint_network)
        joint_network = joint_network_class(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **joint_network_conf,
        )

        self.criterion_transducer = None
        self.error_calculator = None

        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing

        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda
        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network

        # Error-reporting settings used by _calc_transducer_loss; token_list is
        # expected to arrive via kwargs when CER/WER reporting is enabled.
        self.report_cer = report_cer
        self.report_wer = report_wer
        self.sym_space = sym_space
        self.sym_blank = sym_blank
        self.token_list = kwargs.get("token_list", None)

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
        self.ctc = None
        self.ctc_weight = 0.0

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )

        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )
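        # decoder_in is the label sequence with the blank ID prepended (B, U+1),
        # target the padded labels (B, U); t_len/u_len are the per-utterance
        # encoder-output and label lengths consumed by the transducer loss.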

        # 3. Decoder
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)

        # 4. Joint Network
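        # Broadcast encoder frames (B, T, 1, D_enc) against decoder states
        # (B, 1, U+1, D_dec) to produce lattice logits of shape (B, T, U+1, V).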
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )

        # 5. Losses
        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )

        loss_ctc, loss_lm = 0.0, 0.0
        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )
        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)

        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_lm_loss_weight * loss_lm
        )

        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        with autocast(False):
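            # Augmentation and normalization always run in fp32; autocast is
            # disabled here even when the surrounding step uses mixed precision.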
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim) -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)

        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens

    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences. (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Returns:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word error rate for Transducer.
        """
        if self.criterion_transducer is None:
            try:
                from warp_rnnt import rnnt_loss as RNNTLoss

                self.criterion_transducer = RNNTLoss
            except ImportError:
                logging.error(
                    "warp-rnnt was not installed. "
                    "Please consult the installation documentation."
                )
                exit(1)

        log_probs = torch.log_softmax(joint_out, dim=-1)

        loss_transducer = self.criterion_transducer(
            log_probs,
            target,
            t_len,
            u_len,
            reduction="mean",
            blank=self.blank_id,
            fastemit_lambda=self.fastemit_lambda,
            gather=True,
        )
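        # fastemit_lambda > 0 applies FastEmit regularization, biasing the
        # transducer toward earlier emissions to reduce streaming latency;
        # 0.0 leaves the standard RNN-T loss unchanged.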

        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator

                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )

            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
            return loss_transducer, cer_transducer, wer_transducer

        return loss_transducer, None, None

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Returns:
            loss_ctc: CTC loss value.
        """
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
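        # torch.nn.functional.ctc_loss expects time-major log-probs (T, B, V),
        # hence the transpose, and a flat target tensor with padding removed.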
        target_mask = target != 0
        ctc_target = target[target_mask].cpu()

        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
                ctc_target,
                t_len,
                u_len,
                zero_infinity=True,
                reduction="sum",
            )
        loss_ctc /= target.size(0)

        return loss_ctc

    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute LM loss.

        Args:
            decoder_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)

        Returns:
            loss_lm: LM loss value.
        """
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
        lm_target = target.view(-1).type(torch.int64)

        with torch.no_grad():
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))

            # Ignore blank ID (0)
            ignore = lm_target == 0
            lm_target = lm_target.masked_fill(ignore, 0)

            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))

        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
            reduction="none",
        )
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(0)

        return loss_lm

    def init_beam_search(
        self,
        **kwargs,
    ):
        # 1. Build scorers (the CTC prefix scorer is only added when a CTC
        # branch exists; self.ctc is None for this model by default)
        scorers = {}
        if self.ctc is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 3. Build the transducer beam search
        beam_search = BeamSearchTransducer(
            self.decoder,
            self.joint_network,
            kwargs.get("beam_size", 2),
            nbest=1,
        )
        # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        # for scorer in scorers.values():
        #     if isinstance(scorer, torch.nn.Module):
        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        self.beam_search = beam_search
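        # NOTE: the scorers/ngram dict assembled above is never passed to
        # BeamSearchTransducer; the search here scores with the decoder and
        # joint network only.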

    def inference(
        self,
        data_in: list,
        data_lengths: list = None,
        key: list = None,
        tokenizer=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beam search
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        # if self.beam_search is None and (is_use_lm or is_use_ctc):
        logging.info("enable beam_search")
        self.init_beam_search(**kwargs)
        self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = (
            speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
        )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # c. Pass the encoder output to the beam search
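        # batch_size is forced to 1 above, so encoder_out[0] selects the single
        # utterance in the batch.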
        nbest_hyps = self.beam_search(encoder_out[0], is_final=True)

        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq  # [1:last_pos]
                else:
                    token_int = hyp.yseq.tolist()  # [1:last_pos]

                # remove blank symbol id, which is assumed to be 0
                token_int = list(
                    filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)
                )

                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)

                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {
                    "key": key[i],
                    "token": token,
                    "text": text,
                    "text_postprocessed": text_postprocessed,
                }
                results.append(result_i)

                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text
                    ibest_writer["text_postprocessed"][key[i]] = text_postprocessed

        return results, meta_data
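

# A minimal usage sketch (an assumption, not part of this file): decoding is
# normally driven through funasr.AutoModel, which routes kwargs such as
# beam_size and nbest into inference(). The model directory below is a
# hypothetical placeholder for a trained BAT checkpoint.
#
#     from funasr import AutoModel
#
#     model = AutoModel(model="/path/to/bat_model_dir")
#     results = model.generate(input="asr_example.wav", beam_size=5, nbest=1)
#     print(results[0]["text"])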