  1. #!/usr/bin/env python3
  2. # encoding: utf-8
  3. # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Common functions for ASR."""
  6. import json
  7. import logging
  8. import sys
  9. from itertools import groupby
  10. import numpy as np
  11. import six
  12. def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
  13. """End detection.
  14. described in Eq. (50) of S. Watanabe et al
  15. "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
  16. :param ended_hyps:
  17. :param i:
  18. :param M:
  19. :param D_end:
  20. :return:
  21. """
  22. if len(ended_hyps) == 0:
  23. return False
  24. count = 0
  25. best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
  26. for m in six.moves.range(M):
  27. # get ended_hyps with their length is i - m
  28. hyp_length = i - m
  29. hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
  30. if len(hyps_same_length) > 0:
  31. best_hyp_same_length = sorted(
  32. hyps_same_length, key=lambda x: x["score"], reverse=True
  33. )[0]
  34. if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
  35. count += 1
  36. if count == M:
  37. return True
  38. else:
  39. return False
  40. # TODO(takaaki-hori): add different smoothing methods
  41. def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
  42. """Obtain label distribution for loss smoothing.
  43. :param odim:
  44. :param lsm_type:
  45. :param blank:
  46. :param transcript:
  47. :return:
  48. """
  49. if transcript is not None:
  50. with open(transcript, "rb") as f:
  51. trans_json = json.load(f)["utts"]
  52. if lsm_type == "unigram":
  53. assert transcript is not None, (
  54. "transcript is required for %s label smoothing" % lsm_type
  55. )
  56. labelcount = np.zeros(odim)
  57. for k, v in trans_json.items():
  58. ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
  59. # to avoid an error when there is no text in an uttrance
  60. if len(ids) > 0:
  61. labelcount[ids] += 1
  62. labelcount[odim - 1] = len(transcript) # count <eos>
  63. labelcount[labelcount == 0] = 1 # flooring
  64. labelcount[blank] = 0 # remove counts for blank
  65. labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
  66. else:
  67. logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
  68. sys.exit()
  69. return labeldist
  70. def get_vgg2l_odim(idim, in_channel=3, out_channel=128):
  71. """Return the output size of the VGG frontend.
  72. :param in_channel: input channel size
  73. :param out_channel: output channel size
  74. :return: output size
  75. :rtype int
  76. """
  77. idim = idim / in_channel
  78. idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling
  79. idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling
  80. return int(idim) * out_channel # numer of channels
  81. class ErrorCalculator(object):
  82. """Calculate CER and WER for E2E_ASR and CTC models during training.
  83. :param y_hats: numpy array with predicted text
  84. :param y_pads: numpy array with true (target) text
  85. :param char_list:
  86. :param sym_space:
  87. :param sym_blank:
  88. :return:
  89. """
  90. def __init__(
  91. self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
  92. ):
  93. """Construct an ErrorCalculator object."""
  94. super(ErrorCalculator, self).__init__()
  95. self.report_cer = report_cer
  96. self.report_wer = report_wer
  97. self.char_list = char_list
  98. self.space = sym_space
  99. self.blank = sym_blank
  100. self.idx_blank = self.char_list.index(self.blank)
  101. if self.space in self.char_list:
  102. self.idx_space = self.char_list.index(self.space)
  103. else:
  104. self.idx_space = None
  105. def __call__(self, ys_hat, ys_pad, is_ctc=False):
  106. """Calculate sentence-level WER/CER score.
  107. :param torch.Tensor ys_hat: prediction (batch, seqlen)
  108. :param torch.Tensor ys_pad: reference (batch, seqlen)
  109. :param bool is_ctc: calculate CER score for CTC
  110. :return: sentence-level WER score
  111. :rtype float
  112. :return: sentence-level CER score
  113. :rtype float
  114. """
  115. cer, wer = None, None
  116. if is_ctc:
  117. return self.calculate_cer_ctc(ys_hat, ys_pad)
  118. elif not self.report_cer and not self.report_wer:
  119. return cer, wer
  120. seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
  121. if self.report_cer:
  122. cer = self.calculate_cer(seqs_hat, seqs_true)
  123. if self.report_wer:
  124. wer = self.calculate_wer(seqs_hat, seqs_true)
  125. return cer, wer
  126. def calculate_cer_ctc(self, ys_hat, ys_pad):
  127. """Calculate sentence-level CER score for CTC.
  128. :param torch.Tensor ys_hat: prediction (batch, seqlen)
  129. :param torch.Tensor ys_pad: reference (batch, seqlen)
  130. :return: average sentence-level CER score
  131. :rtype float
  132. """
  133. import editdistance
  134. cers, char_ref_lens = [], []
  135. for i, y in enumerate(ys_hat):
  136. y_hat = [x[0] for x in groupby(y)]
  137. y_true = ys_pad[i]
  138. seq_hat, seq_true = [], []
  139. for idx in y_hat:
  140. idx = int(idx)
  141. if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
  142. seq_hat.append(self.char_list[int(idx)])
  143. for idx in y_true:
  144. idx = int(idx)
  145. if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
  146. seq_true.append(self.char_list[int(idx)])
  147. hyp_chars = "".join(seq_hat)
  148. ref_chars = "".join(seq_true)
  149. if len(ref_chars) > 0:
  150. cers.append(editdistance.eval(hyp_chars, ref_chars))
  151. char_ref_lens.append(len(ref_chars))
  152. cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
  153. return cer_ctc
  154. def convert_to_char(self, ys_hat, ys_pad):
  155. """Convert index to character.
  156. :param torch.Tensor seqs_hat: prediction (batch, seqlen)
  157. :param torch.Tensor seqs_true: reference (batch, seqlen)
  158. :return: token list of prediction
  159. :rtype list
  160. :return: token list of reference
  161. :rtype list
  162. """
  163. seqs_hat, seqs_true = [], []
  164. for i, y_hat in enumerate(ys_hat):
  165. y_true = ys_pad[i]
  166. eos_true = np.where(y_true == -1)[0]
  167. ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
  168. # NOTE: padding index (-1) in y_true is used to pad y_hat
  169. seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
  170. seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
  171. seq_hat_text = "".join(seq_hat).replace(self.space, " ")
  172. seq_hat_text = seq_hat_text.replace(self.blank, "")
  173. seq_true_text = "".join(seq_true).replace(self.space, " ")
  174. seqs_hat.append(seq_hat_text)
  175. seqs_true.append(seq_true_text)
  176. return seqs_hat, seqs_true
  177. def calculate_cer(self, seqs_hat, seqs_true):
  178. """Calculate sentence-level CER score.
  179. :param list seqs_hat: prediction
  180. :param list seqs_true: reference
  181. :return: average sentence-level CER score
  182. :rtype float
  183. """
  184. import editdistance
  185. char_eds, char_ref_lens = [], []
  186. for i, seq_hat_text in enumerate(seqs_hat):
  187. seq_true_text = seqs_true[i]
  188. hyp_chars = seq_hat_text.replace(" ", "")
  189. ref_chars = seq_true_text.replace(" ", "")
  190. char_eds.append(editdistance.eval(hyp_chars, ref_chars))
  191. char_ref_lens.append(len(ref_chars))
  192. return float(sum(char_eds)) / sum(char_ref_lens)
  193. def calculate_wer(self, seqs_hat, seqs_true):
  194. """Calculate sentence-level WER score.
  195. :param list seqs_hat: prediction
  196. :param list seqs_true: reference
  197. :return: average sentence-level WER score
  198. :rtype float
  199. """
  200. import editdistance
  201. word_eds, word_ref_lens = [], []
  202. for i, seq_hat_text in enumerate(seqs_hat):
  203. seq_true_text = seqs_true[i]
  204. hyp_words = seq_hat_text.split()
  205. ref_words = seq_true_text.split()
  206. word_eds.append(editdistance.eval(hyp_words, ref_words))
  207. word_ref_lens.append(len(ref_words))
  208. return float(sum(word_eds)) / sum(word_ref_lens)