| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399 |
- #!/usr/bin/env python3
- # encoding: utf-8
- # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
- # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
- """Common functions for ASR."""
- from typing import List, Optional, Tuple
- import json
- import logging
- import sys
- from itertools import groupby
- import numpy as np
- import six
- import torch
- from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
- from funasr.models.joint_net.joint_network import JointNetwork
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection for beam-search decoding.

    Described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition".

    :param list ended_hyps: finished hypotheses; dicts with "yseq" and "score"
    :param int i: current decoding step (current hypothesis length)
    :param int M: number of most recent hypothesis lengths to inspect
    :param float D_end: log-domain score-gap threshold
    :return: True when decoding can be stopped
    :rtype: bool
    """
    if len(ended_hyps) == 0:
        return False

    count = 0
    # Overall best ended hypothesis; max() avoids the O(n log n) sort the
    # previous implementation used just to pick the top element.
    best_hyp = max(ended_hyps, key=lambda x: x["score"])

    # six.moves.range was a Python-2 compatibility relic; plain range suffices.
    for m in range(M):
        # Gather ended hypotheses whose length is i - m.
        hyp_length = i - m
        hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
        if len(hyps_same_length) > 0:
            best_hyp_same_length = max(hyps_same_length, key=lambda x: x["score"])
            if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
                count += 1

    # Stop only if all M recent lengths fell below the threshold.
    return count == M
# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
    """Obtain a label distribution for loss smoothing.

    :param int odim: output (vocabulary) dimension; <eos> is assumed at odim - 1
    :param str lsm_type: label smoothing type; only "unigram" is supported
    :param str transcript: path to a JSON transcript file containing an "utts" dict
    :param int blank: blank symbol id whose count is zeroed out
    :return: unigram label distribution over the vocabulary (sums to 1)
    :rtype: numpy.ndarray
    """
    if transcript is not None:
        with open(transcript, "rb") as f:
            trans_json = json.load(f)["utts"]

    if lsm_type == "unigram":
        assert transcript is not None, (
            "transcript is required for %s label smoothing" % lsm_type
        )
        labelcount = np.zeros(odim)
        for k, v in trans_json.items():
            ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
            # to avoid an error when there is no text in an utterance
            if len(ids) > 0:
                # NOTE(review): fancy indexing increments each distinct id at
                # most once per utterance, not once per occurrence — confirm
                # this is intended (np.add.at would count every occurrence).
                labelcount[ids] += 1
        # One <eos> per utterance. The previous code used len(transcript),
        # which counted the characters of the file *path*, not the number of
        # utterances.
        labelcount[odim - 1] = len(trans_json)  # count <eos>
        labelcount[labelcount == 0] = 1  # flooring
        labelcount[blank] = 0  # remove counts for blank
        labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
    else:
        logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
        sys.exit()
    return labeldist
def get_vgg2l_odim(idim, in_channel=3, out_channel=128):
    """Return the output size of the VGG frontend.

    :param int idim: input feature dimension (stacked over in_channel)
    :param int in_channel: input channel size
    :param int out_channel: output channel size
    :return: flattened output size after the two pooling layers
    :rtype: int
    """
    freq_dim = idim / in_channel
    # The VGG2L frontend applies two max-pooling layers, each halving the
    # frequency dimension (ceiling division, matching the conv output shape).
    for _ in range(2):
        freq_dim = np.ceil(np.array(freq_dim, dtype=np.float32) / 2)
    # Flatten the remaining frequency bins across the output channels.
    return int(freq_dim) * out_channel
class ErrorCalculator(object):
    """Compute CER and WER for E2E ASR and CTC models during training.

    :param list char_list: token inventory mapping ids to characters
    :param str sym_space: token that represents a space
    :param str sym_blank: token that represents the CTC blank
    :param bool report_cer: whether __call__ computes CER
    :param bool report_wer: whether __call__ computes WER
    """

    def __init__(
        self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
    ):
        """Construct an ErrorCalculator object."""
        super(ErrorCalculator, self).__init__()
        self.char_list = char_list
        self.space = sym_space
        self.blank = sym_blank
        self.report_cer = report_cer
        self.report_wer = report_wer
        self.idx_blank = self.char_list.index(self.blank)
        # The space symbol may be absent from the inventory (e.g. char-level
        # Mandarin systems); record None in that case.
        self.idx_space = (
            self.char_list.index(self.space) if self.space in self.char_list else None
        )

    def __call__(self, ys_hat, ys_pad, is_ctc=False):
        """Calculate sentence-level WER/CER scores.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :param bool is_ctc: calculate CER score for CTC
        :return: (CER, WER) pair; either entry is None when not requested
        :rtype: tuple
        """
        if is_ctc:
            return self.calculate_cer_ctc(ys_hat, ys_pad)
        if not (self.report_cer or self.report_wer):
            return None, None

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
        cer = self.calculate_cer(seqs_hat, seqs_true) if self.report_cer else None
        wer = self.calculate_wer(seqs_hat, seqs_true) if self.report_wer else None
        return cer, wer

    def calculate_cer_ctc(self, ys_hat, ys_pad):
        """Calculate sentence-level CER score for CTC.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: average sentence-level CER score, or None if no reference text
        :rtype: float
        """
        import editdistance

        edit_dists, ref_lens = [], []
        # Ids dropped from both sides: padding (-1), blank, and space.
        drop = (-1, self.idx_blank, self.idx_space)
        for hyp, ref in zip(ys_hat, ys_pad):
            # Collapse CTC repeats before mapping ids to characters.
            collapsed = [token for token, _ in groupby(hyp)]
            hyp_chars = "".join(
                self.char_list[int(t)] for t in collapsed if int(t) not in drop
            )
            ref_chars = "".join(
                self.char_list[int(t)] for t in ref if int(t) not in drop
            )
            if ref_chars:
                edit_dists.append(editdistance.eval(hyp_chars, ref_chars))
                ref_lens.append(len(ref_chars))

        return float(sum(edit_dists)) / sum(ref_lens) if edit_dists else None

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index sequences to character strings.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: (prediction texts, reference texts)
        :rtype: tuple of lists
        """
        seqs_hat, seqs_true = [], []
        for hyp, ref in zip(ys_hat, ys_pad):
            pad_pos = np.where(ref == -1)[0]
            # NOTE: padding index (-1) in the reference also truncates the
            # hypothesis.
            ymax = pad_pos[0] if len(pad_pos) > 0 else len(ref)
            hyp_text = "".join(self.char_list[int(t)] for t in hyp[:ymax])
            ref_text = "".join(self.char_list[int(t)] for t in ref if int(t) != -1)
            hyp_text = hyp_text.replace(self.space, " ").replace(self.blank, "")
            ref_text = ref_text.replace(self.space, " ")
            seqs_hat.append(hyp_text)
            seqs_true.append(ref_text)
        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score.

        :param list seqs_hat: prediction texts
        :param list seqs_true: reference texts
        :return: average sentence-level CER score
        :rtype: float
        """
        import editdistance

        edit_dists, ref_lens = [], []
        for hyp_text, ref_text in zip(seqs_hat, seqs_true):
            # CER is computed on characters with spaces removed.
            hyp_chars = hyp_text.replace(" ", "")
            ref_chars = ref_text.replace(" ", "")
            edit_dists.append(editdistance.eval(hyp_chars, ref_chars))
            ref_lens.append(len(ref_chars))
        return float(sum(edit_dists)) / sum(ref_lens)

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score.

        :param list seqs_hat: prediction texts
        :param list seqs_true: reference texts
        :return: average sentence-level WER score
        :rtype: float
        """
        import editdistance

        edit_dists, ref_lens = [], []
        for hyp_text, ref_text in zip(seqs_hat, seqs_true):
            hyp_words = hyp_text.split()
            ref_words = ref_text.split()
            edit_dists.append(editdistance.eval(hyp_words, ref_words))
            ref_lens.append(len(ref_words))
        return float(sum(edit_dists)) / sum(ref_lens)
class ErrorCalculatorTransducer:
    """Calculate CER and WER for transducer models.

    Args:
        decoder: Decoder module.
        joint_network: Joint Network module.
        token_list: List of token units (strings indexed by label id).
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        report_cer: Whether to compute CER.
        report_wer: Whether to compute WER.
    """

    def __init__(
        self,
        decoder,
        joint_network: JointNetwork,
        token_list: List[str],
        sym_space: str,
        sym_blank: str,
        report_cer: bool = False,
        report_wer: bool = False,
    ) -> None:
        """Construct an ErrorCalculatorTransducer object."""
        super().__init__()

        # Greedy (beam_size=1) default search: cheap enough to run during
        # training just for error reporting.
        self.beam_search = BeamSearchTransducer(
            decoder=decoder,
            joint_network=joint_network,
            beam_size=1,
            search_type="default",
            score_norm=False,
        )
        self.decoder = decoder
        self.token_list = token_list
        self.space = sym_space
        self.blank = sym_blank
        self.report_cer = report_cer
        self.report_wer = report_wer

    def __call__(
        self, encoder_out: torch.Tensor, target: torch.Tensor
    ) -> Tuple[Optional[float], Optional[float]]:
        """Calculate sentence-level WER or/and CER score for Transducer model.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)

        Returns:
            : Sentence-level CER score (None when report_cer is False).
            : Sentence-level WER score (None when report_wer is False).
        """
        cer, wer = None, None
        batchsize = int(encoder_out.size(0))
        # Move encoder output to wherever the decoder parameters live.
        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
        # Decode each utterance independently; beam_search returns an n-best
        # list per utterance.
        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
        # Keep the best hypothesis only; yseq[1:] drops the first label
        # (presumably a start-of-sequence/blank token — confirm against
        # BeamSearchTransducer).
        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
        char_pred, char_target = self.convert_to_char(pred, target)
        if self.report_cer:
            cer = self.calculate_cer(char_pred, char_target)
        if self.report_wer:
            wer = self.calculate_wer(char_pred, char_target)
        return cer, wer

    def convert_to_char(
        self, pred: List, target: torch.Tensor
    ) -> Tuple[List, List]:
        """Convert label ID sequences to character sequences.

        Args:
            pred: Prediction label ID sequences. (B, U)
            target: Target label ID sequences. (B, L)

        Returns:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)
        """
        char_pred, char_target = [], []
        for i, pred_i in enumerate(pred):
            char_pred_i = [self.token_list[int(h)] for h in pred_i]
            char_target_i = [self.token_list[int(r)] for r in target[i]]
            # Map the space symbol to a literal space and strip blanks.
            char_pred_i = "".join(char_pred_i).replace(self.space, " ")
            char_pred_i = char_pred_i.replace(self.blank, "")
            char_target_i = "".join(char_target_i).replace(self.space, " ")
            char_target_i = char_target_i.replace(self.blank, "")
            char_pred.append(char_pred_i)
            char_target.append(char_target_i)
        return char_pred, char_target

    def calculate_cer(
        self, char_pred: List[str], char_target: List[str]
    ) -> float:
        """Calculate sentence-level CER score.

        Args:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        Returns:
            : Average sentence-level CER score.
        """
        import editdistance

        distances, lens = [], []
        for i, char_pred_i in enumerate(char_pred):
            # CER is computed on characters with spaces removed.
            pred = char_pred_i.replace(" ", "")
            target = char_target[i].replace(" ", "")
            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))
        return float(sum(distances)) / sum(lens)

    def calculate_wer(
        self, char_pred: List[str], char_target: List[str]
    ) -> float:
        """Calculate sentence-level WER score.

        Args:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        Returns:
            : Average sentence-level WER score.
        """
        import editdistance

        distances, lens = [], []
        for i, char_pred_i in enumerate(char_pred):
            # "▁" is the sentencepiece word-boundary marker; restore spaces
            # before splitting into words.
            pred = char_pred_i.replace("▁", " ").split()
            target = char_target[i].replace("▁", " ").split()
            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))
        return float(sum(distances)) / sum(lens)
|