  1. #!/usr/bin/env python3
  2. # encoding: utf-8
  3. # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Common functions for ASR."""
  6. from typing import List, Optional, Tuple
  7. import json
  8. import logging
  9. import sys
  10. from itertools import groupby
  11. import numpy as np
  12. import six
  13. import torch
  14. from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
  15. from funasr.models.joint_net.joint_network import JointNetwork
  16. def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
  17. """End detection.
  18. described in Eq. (50) of S. Watanabe et al
  19. "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
  20. :param ended_hyps:
  21. :param i:
  22. :param M:
  23. :param D_end:
  24. :return:
  25. """
  26. if len(ended_hyps) == 0:
  27. return False
  28. count = 0
  29. best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
  30. for m in six.moves.range(M):
  31. # get ended_hyps with their length is i - m
  32. hyp_length = i - m
  33. hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
  34. if len(hyps_same_length) > 0:
  35. best_hyp_same_length = sorted(
  36. hyps_same_length, key=lambda x: x["score"], reverse=True
  37. )[0]
  38. if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
  39. count += 1
  40. if count == M:
  41. return True
  42. else:
  43. return False
  44. # TODO(takaaki-hori): add different smoothing methods
  45. def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
  46. """Obtain label distribution for loss smoothing.
  47. :param odim:
  48. :param lsm_type:
  49. :param blank:
  50. :param transcript:
  51. :return:
  52. """
  53. if transcript is not None:
  54. with open(transcript, "rb") as f:
  55. trans_json = json.load(f)["utts"]
  56. if lsm_type == "unigram":
  57. assert transcript is not None, (
  58. "transcript is required for %s label smoothing" % lsm_type
  59. )
  60. labelcount = np.zeros(odim)
  61. for k, v in trans_json.items():
  62. ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
  63. # to avoid an error when there is no text in an uttrance
  64. if len(ids) > 0:
  65. labelcount[ids] += 1
  66. labelcount[odim - 1] = len(transcript) # count <eos>
  67. labelcount[labelcount == 0] = 1 # flooring
  68. labelcount[blank] = 0 # remove counts for blank
  69. labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
  70. else:
  71. logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
  72. sys.exit()
  73. return labeldist
  74. def get_vgg2l_odim(idim, in_channel=3, out_channel=128):
  75. """Return the output size of the VGG frontend.
  76. :param in_channel: input channel size
  77. :param out_channel: output channel size
  78. :return: output size
  79. :rtype int
  80. """
  81. idim = idim / in_channel
  82. idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling
  83. idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling
  84. return int(idim) * out_channel # numer of channels
  85. class ErrorCalculator(object):
  86. """Calculate CER and WER for E2E_ASR and CTC models during training.
  87. :param y_hats: numpy array with predicted text
  88. :param y_pads: numpy array with true (target) text
  89. :param char_list:
  90. :param sym_space:
  91. :param sym_blank:
  92. :return:
  93. """
  94. def __init__(
  95. self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
  96. ):
  97. """Construct an ErrorCalculator object."""
  98. super(ErrorCalculator, self).__init__()
  99. self.report_cer = report_cer
  100. self.report_wer = report_wer
  101. self.char_list = char_list
  102. self.space = sym_space
  103. self.blank = sym_blank
  104. self.idx_blank = self.char_list.index(self.blank)
  105. if self.space in self.char_list:
  106. self.idx_space = self.char_list.index(self.space)
  107. else:
  108. self.idx_space = None
  109. def __call__(self, ys_hat, ys_pad, is_ctc=False):
  110. """Calculate sentence-level WER/CER score.
  111. :param torch.Tensor ys_hat: prediction (batch, seqlen)
  112. :param torch.Tensor ys_pad: reference (batch, seqlen)
  113. :param bool is_ctc: calculate CER score for CTC
  114. :return: sentence-level WER score
  115. :rtype float
  116. :return: sentence-level CER score
  117. :rtype float
  118. """
  119. cer, wer = None, None
  120. if is_ctc:
  121. return self.calculate_cer_ctc(ys_hat, ys_pad)
  122. elif not self.report_cer and not self.report_wer:
  123. return cer, wer
  124. seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
  125. if self.report_cer:
  126. cer = self.calculate_cer(seqs_hat, seqs_true)
  127. if self.report_wer:
  128. wer = self.calculate_wer(seqs_hat, seqs_true)
  129. return cer, wer
  130. def calculate_cer_ctc(self, ys_hat, ys_pad):
  131. """Calculate sentence-level CER score for CTC.
  132. :param torch.Tensor ys_hat: prediction (batch, seqlen)
  133. :param torch.Tensor ys_pad: reference (batch, seqlen)
  134. :return: average sentence-level CER score
  135. :rtype float
  136. """
  137. import editdistance
  138. cers, char_ref_lens = [], []
  139. for i, y in enumerate(ys_hat):
  140. y_hat = [x[0] for x in groupby(y)]
  141. y_true = ys_pad[i]
  142. seq_hat, seq_true = [], []
  143. for idx in y_hat:
  144. idx = int(idx)
  145. if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
  146. seq_hat.append(self.char_list[int(idx)])
  147. for idx in y_true:
  148. idx = int(idx)
  149. if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
  150. seq_true.append(self.char_list[int(idx)])
  151. hyp_chars = "".join(seq_hat)
  152. ref_chars = "".join(seq_true)
  153. if len(ref_chars) > 0:
  154. cers.append(editdistance.eval(hyp_chars, ref_chars))
  155. char_ref_lens.append(len(ref_chars))
  156. cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
  157. return cer_ctc
  158. def convert_to_char(self, ys_hat, ys_pad):
  159. """Convert index to character.
  160. :param torch.Tensor seqs_hat: prediction (batch, seqlen)
  161. :param torch.Tensor seqs_true: reference (batch, seqlen)
  162. :return: token list of prediction
  163. :rtype list
  164. :return: token list of reference
  165. :rtype list
  166. """
  167. seqs_hat, seqs_true = [], []
  168. for i, y_hat in enumerate(ys_hat):
  169. y_true = ys_pad[i]
  170. eos_true = np.where(y_true == -1)[0]
  171. ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
  172. # NOTE: padding index (-1) in y_true is used to pad y_hat
  173. seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
  174. seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
  175. seq_hat_text = "".join(seq_hat).replace(self.space, " ")
  176. seq_hat_text = seq_hat_text.replace(self.blank, "")
  177. seq_true_text = "".join(seq_true).replace(self.space, " ")
  178. seqs_hat.append(seq_hat_text)
  179. seqs_true.append(seq_true_text)
  180. return seqs_hat, seqs_true
  181. def calculate_cer(self, seqs_hat, seqs_true):
  182. """Calculate sentence-level CER score.
  183. :param list seqs_hat: prediction
  184. :param list seqs_true: reference
  185. :return: average sentence-level CER score
  186. :rtype float
  187. """
  188. import editdistance
  189. char_eds, char_ref_lens = [], []
  190. for i, seq_hat_text in enumerate(seqs_hat):
  191. seq_true_text = seqs_true[i]
  192. hyp_chars = seq_hat_text.replace(" ", "")
  193. ref_chars = seq_true_text.replace(" ", "")
  194. char_eds.append(editdistance.eval(hyp_chars, ref_chars))
  195. char_ref_lens.append(len(ref_chars))
  196. return float(sum(char_eds)) / sum(char_ref_lens)
  197. def calculate_wer(self, seqs_hat, seqs_true):
  198. """Calculate sentence-level WER score.
  199. :param list seqs_hat: prediction
  200. :param list seqs_true: reference
  201. :return: average sentence-level WER score
  202. :rtype float
  203. """
  204. import editdistance
  205. word_eds, word_ref_lens = [], []
  206. for i, seq_hat_text in enumerate(seqs_hat):
  207. seq_true_text = seqs_true[i]
  208. hyp_words = seq_hat_text.split()
  209. ref_words = seq_true_text.split()
  210. word_eds.append(editdistance.eval(hyp_words, ref_words))
  211. word_ref_lens.append(len(ref_words))
  212. return float(sum(word_eds)) / sum(word_ref_lens)
  213. class ErrorCalculatorTransducer:
  214. """Calculate CER and WER for transducer models.
  215. Args:
  216. decoder: Decoder module.
  217. joint_network: Joint Network module.
  218. token_list: List of token units.
  219. sym_space: Space symbol.
  220. sym_blank: Blank symbol.
  221. report_cer: Whether to compute CER.
  222. report_wer: Whether to compute WER.
  223. """
  224. def __init__(
  225. self,
  226. decoder,
  227. joint_network: JointNetwork,
  228. token_list: List[int],
  229. sym_space: str,
  230. sym_blank: str,
  231. report_cer: bool = False,
  232. report_wer: bool = False,
  233. ) -> None:
  234. """Construct an ErrorCalculatorTransducer object."""
  235. super().__init__()
  236. self.beam_search = BeamSearchTransducer(
  237. decoder=decoder,
  238. joint_network=joint_network,
  239. beam_size=1,
  240. search_type="default",
  241. score_norm=False,
  242. )
  243. self.decoder = decoder
  244. self.token_list = token_list
  245. self.space = sym_space
  246. self.blank = sym_blank
  247. self.report_cer = report_cer
  248. self.report_wer = report_wer
  249. def __call__(
  250. self, encoder_out: torch.Tensor, target: torch.Tensor, encoder_out_lens: torch.Tensor,
  251. ) -> Tuple[Optional[float], Optional[float]]:
  252. """Calculate sentence-level WER or/and CER score for Transducer model.
  253. Args:
  254. encoder_out: Encoder output sequences. (B, T, D_enc)
  255. target: Target label ID sequences. (B, L)
  256. encoder_out_lens: Encoder output sequences length. (B,)
  257. Returns:
  258. : Sentence-level CER score.
  259. : Sentence-level WER score.
  260. """
  261. cer, wer = None, None
  262. batchsize = int(encoder_out.size(0))
  263. encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
  264. batch_nbest = [
  265. self.beam_search(encoder_out[b][: encoder_out_lens[b]])
  266. for b in range(batchsize)
  267. ]
  268. pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
  269. char_pred, char_target = self.convert_to_char(pred, target)
  270. if self.report_cer:
  271. cer = self.calculate_cer(char_pred, char_target)
  272. if self.report_wer:
  273. wer = self.calculate_wer(char_pred, char_target)
  274. return cer, wer
  275. def convert_to_char(
  276. self, pred: torch.Tensor, target: torch.Tensor
  277. ) -> Tuple[List, List]:
  278. """Convert label ID sequences to character sequences.
  279. Args:
  280. pred: Prediction label ID sequences. (B, U)
  281. target: Target label ID sequences. (B, L)
  282. Returns:
  283. char_pred: Prediction character sequences. (B, ?)
  284. char_target: Target character sequences. (B, ?)
  285. """
  286. char_pred, char_target = [], []
  287. for i, pred_i in enumerate(pred):
  288. char_pred_i = [self.token_list[int(h)] for h in pred_i]
  289. char_target_i = [self.token_list[int(r)] for r in target[i]]
  290. char_pred_i = "".join(char_pred_i).replace(self.space, " ")
  291. char_pred_i = char_pred_i.replace(self.blank, "")
  292. char_target_i = "".join(char_target_i).replace(self.space, " ")
  293. char_target_i = char_target_i.replace(self.blank, "")
  294. char_pred.append(char_pred_i)
  295. char_target.append(char_target_i)
  296. return char_pred, char_target
  297. def calculate_cer(
  298. self, char_pred: torch.Tensor, char_target: torch.Tensor
  299. ) -> float:
  300. """Calculate sentence-level CER score.
  301. Args:
  302. char_pred: Prediction character sequences. (B, ?)
  303. char_target: Target character sequences. (B, ?)
  304. Returns:
  305. : Average sentence-level CER score.
  306. """
  307. import editdistance
  308. distances, lens = [], []
  309. for i, char_pred_i in enumerate(char_pred):
  310. pred = char_pred_i.replace(" ", "")
  311. target = char_target[i].replace(" ", "")
  312. distances.append(editdistance.eval(pred, target))
  313. lens.append(len(target))
  314. return float(sum(distances)) / sum(lens)
  315. def calculate_wer(
  316. self, char_pred: torch.Tensor, char_target: torch.Tensor
  317. ) -> float:
  318. """Calculate sentence-level WER score.
  319. Args:
  320. char_pred: Prediction character sequences. (B, ?)
  321. char_target: Target character sequences. (B, ?)
  322. Returns:
  323. : Average sentence-level WER score
  324. """
  325. import editdistance
  326. distances, lens = [], []
  327. for i, char_pred_i in enumerate(char_pred):
  328. pred = char_pred_i.replace("▁", " ").split()
  329. target = char_target[i].replace("▁", " ").split()
  330. distances.append(editdistance.eval(pred, target))
  331. lens.append(len(target))
  332. return float(sum(distances)) / sum(lens)