- """RNN decoder module."""
- import logging
- import math
- import random
- from argparse import Namespace
- import numpy as np
- import six
- import torch
- import torch.nn.functional as F
- from funasr.modules.scorers.ctc_prefix_score import CTCPrefixScore
- from funasr.modules.scorers.ctc_prefix_score import CTCPrefixScoreTH
- from funasr.modules.scorers.scorer_interface import ScorerInterface
- from funasr.modules.e2e_asr_common import end_detect
- from funasr.modules.nets_utils import mask_by_length
- from funasr.modules.nets_utils import pad_list
- from funasr.modules.nets_utils import th_accuracy
- from funasr.modules.nets_utils import to_device
- from funasr.modules.rnn.attentions import att_to_numpy
- MAX_DECODER_OUTPUT = 5
- CTC_SCORING_RATIO = 1.5


class Decoder(torch.nn.Module, ScorerInterface):
    """Decoder module.

    :param int eprojs: number of encoder projection units
    :param int odim: dimension of outputs
    :param str dtype: gru or lstm
    :param int dlayers: number of decoder layers
    :param int dunits: number of decoder units
    :param int sos: start of sequence symbol id
    :param int eos: end of sequence symbol id
    :param torch.nn.Module att: attention module
    :param int verbose: verbose level
    :param list char_list: list of character strings
    :param ndarray labeldist: distribution for label smoothing
    :param float lsm_weight: label smoothing weight
    :param float sampling_probability: scheduled sampling probability
    :param float dropout: dropout rate
    :param bool context_residual: if True, use the context vector for token generation
    :param bool replace_sos: replace <sos> with a target language id
        (used for multilingual speech/text translation)
    :param int num_encs: number of encoders
    """

    def __init__(
        self,
        eprojs,
        odim,
        dtype,
        dlayers,
        dunits,
        sos,
        eos,
        att,
        verbose=0,
        char_list=None,
        labeldist=None,
        lsm_weight=0.0,
        sampling_probability=0.0,
        dropout=0.0,
        context_residual=False,
        replace_sos=False,
        num_encs=1,
    ):
        torch.nn.Module.__init__(self)
        self.dtype = dtype
        self.dunits = dunits
        self.dlayers = dlayers
        self.context_residual = context_residual
        self.embed = torch.nn.Embedding(odim, dunits)
        self.dropout_emb = torch.nn.Dropout(p=dropout)

        self.decoder = torch.nn.ModuleList()
        self.dropout_dec = torch.nn.ModuleList()
        self.decoder += [
            torch.nn.LSTMCell(dunits + eprojs, dunits)
            if self.dtype == "lstm"
            else torch.nn.GRUCell(dunits + eprojs, dunits)
        ]
        self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        for _ in six.moves.range(1, self.dlayers):
            self.decoder += [
                torch.nn.LSTMCell(dunits, dunits)
                if self.dtype == "lstm"
                else torch.nn.GRUCell(dunits, dunits)
            ]
            self.dropout_dec += [torch.nn.Dropout(p=dropout)]
            # NOTE: dropout is applied only for the vertical connections
            # see https://arxiv.org/pdf/1409.2329.pdf
        self.ignore_id = -1

        if context_residual:
            self.output = torch.nn.Linear(dunits + eprojs, odim)
        else:
            self.output = torch.nn.Linear(dunits, odim)

        self.loss = None
        self.att = att
        self.sos = sos
        self.eos = eos
        self.odim = odim
        self.verbose = verbose
        self.char_list = char_list
        # for label smoothing
        self.labeldist = labeldist
        self.vlabeldist = None
        self.lsm_weight = lsm_weight
        self.sampling_probability = sampling_probability
        self.dropout = dropout
        self.num_encs = num_encs

        # for multilingual E2E-ST
        self.replace_sos = replace_sos

        self.logzero = -10000000000.0

    def zero_state(self, hs_pad):
        return hs_pad.new_zeros(hs_pad.size(0), self.dunits)

    def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
        if self.dtype == "lstm":
            z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
            for i in six.moves.range(1, self.dlayers):
                z_list[i], c_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]), (z_prev[i], c_prev[i])
                )
        else:
            # GRU cells carry no cell state, so c_list is passed through unchanged
            z_list[0] = self.decoder[0](ey, z_prev[0])
            for i in six.moves.range(1, self.dlayers):
                z_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
                )
        return z_list, c_list

    def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
        """Decoder forward.

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            [in multi-encoder case, list of torch.Tensor,
            [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
            [in multi-encoder case, list of torch.Tensor,
            [(B), (B), ..., ] ]
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor
            (B, Lmax)
        :param int strm_idx: stream index for speaker parallel attention
            in the multi-speaker case
        :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy
        :rtype: float
        :return: perplexity
        :rtype: float
        """
        # to support multiple-encoder asr mode: in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]

        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select the attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att) - 1)

        # hlens should be list of list of integer
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # pad the input side with <eos> and the output side with ignore_id (-1)
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)

        # get dim, length info
        batch = ys_out_pad.size(0)
        olength = ys_out_pad.size(1)
        for idx in range(self.num_encs):
            logging.info(
                self.__class__.__name__
                + " Number of encoders: {}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, hlens[idx]
                )
            )
        logging.info(
            self.__class__.__name__
            + " output lengths: "
            + str([y.size(0) for y in ys_out])
        )

        # initialization
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in six.moves.range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
                )
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        hs_pad[idx],
                        hlens[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlens_han = [self.num_encs] * len(ys_in)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    hs_pad_han,
                    hlens_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
            if i > 0 and random.random() < self.sampling_probability:
                logging.info(" scheduled sampling ")
                z_out = self.output(z_all[-1])
                z_out = np.argmax(z_out.detach().cpu(), axis=1)
                z_out = self.dropout_emb(self.embed(to_device(hs_pad[0], z_out)))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )  # utt x (zdim + hdim)
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

        z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
        # compute loss
        y_all = self.output(z_all)
        self.loss = F.cross_entropy(
            y_all,
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="mean",
        )
        # compute perplexity
        ppl = math.exp(self.loss.item())
        # -1: eos, which is removed in the loss computation
        self.loss *= np.mean([len(x) for x in ys_in]) - 1
        acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
        logging.info("att loss:" + "".join(str(self.loss.item()).split("\n")))

        # show predicted character sequence for debug
        if self.verbose > 0 and self.char_list is not None:
            ys_hat = y_all.view(batch, olength, -1)
            ys_true = ys_out_pad
            for (i, y_hat), y_true in zip(
                enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy()
            ):
                if i == MAX_DECODER_OUTPUT:
                    break
                idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
                idx_true = y_true[y_true != self.ignore_id]
                seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
                seq_true = [self.char_list[int(idx)] for idx in idx_true]
                seq_hat = "".join(seq_hat)
                seq_true = "".join(seq_true)
                logging.info("groundtruth[%d]: " % i + seq_true)
                logging.info("prediction [%d]: " % i + seq_hat)

        if self.labeldist is not None:
            if self.vlabeldist is None:
                # assign the label-smoothing distribution on the first call
                self.vlabeldist = to_device(hs_pad[0], torch.from_numpy(self.labeldist))
            loss_reg = -torch.sum(
                (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0
            ) / len(ys_in)
            self.loss = (1.0 - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

        return self.loss, acc, ppl

    def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
        """Beam search implementation.

        :param torch.Tensor h: encoder hidden state (T, eprojs)
            [in multi-encoder case, list of torch.Tensor,
            [(T1, eprojs), (T2, eprojs), ...] ]
        :param torch.Tensor lpz: ctc log softmax output (T, odim)
            [in multi-encoder case, list of torch.Tensor,
            [(T1, odim), (T2, odim), ...] ]
        :param Namespace recog_args: argument Namespace containing options
        :param char_list: list of character strings
        :param torch.nn.Module rnnlm: language module
        :param int strm_idx:
            stream index for speaker parallel attention in multi-speaker case
        :return: N-best decoding results
        :rtype: list of dicts
        """
        # to support multiple-encoder asr mode: in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            h = [h]
            lpz = [lpz]
        if self.num_encs > 1 and lpz is None:
            lpz = [lpz] * self.num_encs

        for idx in range(self.num_encs):
            logging.info(
                "Number of encoders: {}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, h[idx].size(0)
                )
            )
        att_idx = min(strm_idx, len(self.att) - 1)
        # initialization
        c_list = [self.zero_state(h[0].unsqueeze(0))]
        z_list = [self.zero_state(h[0].unsqueeze(0))]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(h[0].unsqueeze(0)))
            z_list.append(self.zero_state(h[0].unsqueeze(0)))
        if self.num_encs == 1:
            a = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            a = [None] * (self.num_encs + 1)  # atts + han
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han

        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = getattr(recog_args, "ctc_weight", False)  # for NMT
        if lpz[0] is not None and self.num_encs > 1:
            # weights-ctc,
            # e.g. ctc_loss = w_1 * ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
            weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
                recog_args.weights_ctc_dec
            )  # normalize
            logging.info(
                "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
            )
        else:
            weights_ctc_dec = [1.0]

        # prepare sos
        if self.replace_sos and recog_args.tgt_lang:
            y = char_list.index(recog_args.tgt_lang)
        else:
            y = self.sos
        logging.info("<sos> index: " + str(y))
        logging.info("<sos> mark: " + char_list[y])
        vy = h[0].new_zeros(1).long()

        maxlen = np.amin([h[idx].size(0) for idx in range(self.num_encs)])
        if recog_args.maxlenratio != 0:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * maxlen))
        minlen = int(recog_args.minlenratio * maxlen)
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {
                "score": 0.0,
                "yseq": [y],
                "c_prev": c_list,
                "z_prev": z_list,
                "a_prev": a,
                "rnnlm_prev": None,
            }
        else:
            hyp = {
                "score": 0.0,
                "yseq": [y],
                "c_prev": c_list,
                "z_prev": z_list,
                "a_prev": a,
            }
        if lpz[0] is not None:
            ctc_prefix_score = [
                CTCPrefixScore(lpz[idx].detach().numpy(), 0, self.eos, np)
                for idx in range(self.num_encs)
            ]
            hyp["ctc_state_prev"] = [
                ctc_prefix_score[idx].initial_state() for idx in range(self.num_encs)
            ]
            hyp["ctc_score_prev"] = [0.0] * self.num_encs
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz[0].shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz[0].shape[-1]
        hyps = [hyp]
        ended_hyps = []

        for i in six.moves.range(maxlen):
            logging.debug("position " + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp["yseq"][i]
                ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
                if self.num_encs == 1:
                    att_c, att_w = self.att[att_idx](
                        h[0].unsqueeze(0),
                        [h[0].size(0)],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"],
                    )
                else:
                    for idx in range(self.num_encs):
                        att_c_list[idx], att_w_list[idx] = self.att[idx](
                            h[idx].unsqueeze(0),
                            [h[idx].size(0)],
                            self.dropout_dec[0](hyp["z_prev"][0]),
                            hyp["a_prev"][idx],
                        )
                    h_han = torch.stack(att_c_list, dim=1)
                    att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                        h_han,
                        [self.num_encs],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"][self.num_encs],
                    )
                ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
                z_list, c_list = self.rnn_forward(
                    ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]
                )

                # get nbest local scores and their ids
                if self.context_residual:
                    logits = self.output(
                        torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                    )
                else:
                    logits = self.output(self.dropout_dec[-1](z_list[-1]))
                local_att_scores = F.log_softmax(logits, dim=1)
                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                    local_scores = (
                        local_att_scores + recog_args.lm_weight * local_lm_scores
                    )
                else:
                    local_scores = local_att_scores

                if lpz[0] is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1
                    )
                    ctc_scores, ctc_states = (
                        [None] * self.num_encs,
                        [None] * self.num_encs,
                    )
                    for idx in range(self.num_encs):
                        ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
                            hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx]
                        )
                    local_scores = (1.0 - ctc_weight) * local_att_scores[
                        :, local_best_ids[0]
                    ]
                    if self.num_encs == 1:
                        local_scores += ctc_weight * torch.from_numpy(
                            ctc_scores[0] - hyp["ctc_score_prev"][0]
                        )
                    else:
                        for idx in range(self.num_encs):
                            local_scores += (
                                ctc_weight
                                * weights_ctc_dec[idx]
                                * torch.from_numpy(
                                    ctc_scores[idx] - hyp["ctc_score_prev"][idx]
                                )
                            )
                    if rnnlm:
                        local_scores += (
                            recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                        )
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )

                for j in six.moves.range(beam):
                    new_hyp = {}
                    # [:] is needed!
                    new_hyp["z_prev"] = z_list[:]
                    new_hyp["c_prev"] = c_list[:]
                    if self.num_encs == 1:
                        new_hyp["a_prev"] = att_w[:]
                    else:
                        new_hyp["a_prev"] = [
                            att_w_list[idx][:] for idx in range(self.num_encs + 1)
                        ]
                    new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                    new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                    new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                    new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp["rnnlm_prev"] = rnnlm_state
                    if lpz[0] is not None:
                        new_hyp["ctc_state_prev"] = [
                            ctc_states[idx][joint_best_ids[0, j]]
                            for idx in range(self.num_encs)
                        ]
                        new_hyp["ctc_score_prev"] = [
                            ctc_scores[idx][joint_best_ids[0, j]]
                            for idx in range(self.num_encs)
                        ]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x["score"], reverse=True
                )[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug("number of pruned hypotheses: " + str(len(hyps)))
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
            )

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info("adding <eos> in the last position in the loop")
                for hyp in hyps:
                    hyp["yseq"].append(self.eos)

            # add ended hypotheses to the final list,
            # and remove them from the current hypotheses
            # (this may leave fewer than beam hypotheses in the active set)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    # only store sequences that have more than minlen outputs
                    # also add the length penalty
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp["score"] += recog_args.lm_weight * rnnlm.final(
                                hyp["rnnlm_prev"]
                            )
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info("end detected at %d", i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug("remaining hypotheses: " + str(len(hyps)))
            else:
                logging.info("no hypothesis. Finish decoding.")
                break

            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                )

            logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps), recog_args.nbest)
        ]

        # check the number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning(
                "there are no N-best results, "
                "performing recognition again with a smaller minlenratio."
            )
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            if self.num_encs == 1:
                return self.recognize_beam(h[0], lpz[0], recog_args, char_list, rnnlm)
            else:
                return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

        logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
        logging.info(
            "normalized log probability: "
            + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
        )

        # remove sos
        return nbest_hyps

    def recognize_beam_batch(
        self,
        h,
        hlens,
        lpz,
        recog_args,
        char_list,
        rnnlm=None,
        normalize_score=True,
        strm_idx=0,
        lang_ids=None,
    ):
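        """Batch beam search implementation.

        Arguments mirror :meth:`recognize_beam`, except that ``h`` is a padded
        batch of encoder outputs (B, Tmax, eprojs) with lengths ``hlens``, and
        the result is a list with one N-best hypothesis list per utterance.
        """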
        # to support multiple-encoder asr mode: in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            h = [h]
            hlens = [hlens]
            lpz = [lpz]
        if self.num_encs > 1 and lpz is None:
            lpz = [lpz] * self.num_encs

        att_idx = min(strm_idx, len(self.att) - 1)
        for idx in range(self.num_encs):
            logging.info(
                "Number of encoders: {}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, h[idx].size(1)
                )
            )
            h[idx] = mask_by_length(h[idx], hlens[idx], 0.0)

        # search params
        batch = len(hlens[0])
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = getattr(recog_args, "ctc_weight", 0)  # for NMT
        att_weight = 1.0 - ctc_weight
        ctc_margin = getattr(
            recog_args, "ctc_window_margin", 0
        )  # use getattr to keep compatibility
        # weights-ctc,
        # e.g. ctc_loss = w_1 * ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
        if lpz[0] is not None and self.num_encs > 1:
            weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
                recog_args.weights_ctc_dec
            )  # normalize
            logging.info(
                "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
            )
        else:
            weights_ctc_dec = [1.0]

        n_bb = batch * beam
        pad_b = to_device(h[0], torch.arange(batch) * beam).view(-1, 1)

        max_hlen = np.amin([max(hlens[idx]) for idx in range(self.num_encs)])
        if recog_args.maxlenratio == 0:
            maxlen = max_hlen
        else:
            maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
        minlen = int(recog_args.minlenratio * max_hlen)
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # initialization
        c_prev = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        z_prev = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        c_list = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        z_list = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        vscores = to_device(h[0], torch.zeros(batch, beam))

        rnnlm_state = None
        if self.num_encs == 1:
            a_prev = [None]
            att_w_list, ctc_scorer, ctc_state = [None], [None], [None]
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            a_prev = [None] * (self.num_encs + 1)  # atts + han
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            ctc_scorer, ctc_state = [None] * (self.num_encs), [None] * (self.num_encs)
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han

        if self.replace_sos and recog_args.tgt_lang:
            logging.info("<sos> index: " + str(char_list.index(recog_args.tgt_lang)))
            logging.info("<sos> mark: " + recog_args.tgt_lang)
            yseq = [
                [char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)
            ]
        elif lang_ids is not None:
            # NOTE: used for evaluation during training
            yseq = [
                [lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)
            ]
        else:
            logging.info("<sos> index: " + str(self.sos))
            logging.info("<sos> mark: " + char_list[self.sos])
            yseq = [[self.sos] for _ in six.moves.range(n_bb)]

        accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)]
        stop_search = [False for _ in six.moves.range(batch)]
        nbest_hyps = [[] for _ in six.moves.range(batch)]
        ended_hyps = [[] for _ in range(batch)]

        exp_hlens = [
            hlens[idx].repeat(beam).view(beam, batch).transpose(0, 1).contiguous()
            for idx in range(self.num_encs)
        ]
        exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
        exp_h = [
            h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
            for idx in range(self.num_encs)
        ]
        exp_h = [
            exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2])
            for idx in range(self.num_encs)
        ]

        if lpz[0] is not None:
            scoring_num = min(
                int(beam * CTC_SCORING_RATIO)
                if att_weight > 0.0 and not lpz[0].is_cuda
                else 0,
                lpz[0].size(-1),
            )
            ctc_scorer = [
                CTCPrefixScoreTH(
                    lpz[idx],
                    hlens[idx],
                    0,
                    self.eos,
                    margin=ctc_margin,
                )
                for idx in range(self.num_encs)
            ]

        for i in six.moves.range(maxlen):
            logging.debug("position " + str(i))

            vy = to_device(h[0], torch.LongTensor(self._get_last_yseq(yseq)))
            ey = self.dropout_emb(self.embed(vy))
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    exp_h[0], exp_hlens[0], self.dropout_dec[0](z_prev[0]), a_prev[0]
                )
                att_w_list = [att_w]
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        exp_h[idx],
                        exp_hlens[idx],
                        self.dropout_dec[0](z_prev[0]),
                        a_prev[idx],
                    )
                exp_h_han = torch.stack(att_c_list, dim=1)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    exp_h_han,
                    [self.num_encs] * n_bb,
                    self.dropout_dec[0](z_prev[0]),
                    a_prev[self.num_encs],
                )
            ey = torch.cat((ey, att_c), dim=1)

            # attention decoder
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
            if self.context_residual:
                logits = self.output(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )
            else:
                logits = self.output(self.dropout_dec[-1](z_list[-1]))
            local_scores = att_weight * F.log_softmax(logits, dim=1)

            # rnnlm
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb)
                local_scores = local_scores + recog_args.lm_weight * local_lm_scores

            # ctc
            if ctc_scorer[0]:
                local_scores[:, 0] = self.logzero  # avoid choosing blank
                part_ids = (
                    torch.topk(local_scores, scoring_num, dim=-1)[1]
                    if scoring_num > 0
                    else None
                )
                for idx in range(self.num_encs):
                    att_w = att_w_list[idx]
                    att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0]
                    local_ctc_scores, ctc_state[idx] = ctc_scorer[idx](
                        yseq, ctc_state[idx], part_ids, att_w_
                    )
                    local_scores = (
                        local_scores
                        + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
                    )

            local_scores = local_scores.view(batch, beam, self.odim)
            if i == 0:
                local_scores[:, 1:, :] = self.logzero

            # accumulate scores
            eos_vscores = local_scores[:, :, self.eos] + vscores
            vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
            vscores[:, :, self.eos] = self.logzero
            vscores = (vscores + local_scores).view(batch, -1)

            # global pruning
            accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
            accum_odim_ids = (
                torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
            )
            accum_padded_beam_ids = (
                (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()
            )

            y_prev = yseq[:][:]
            yseq = self._index_select_list(yseq, accum_padded_beam_ids)
            yseq = self._append_ids(yseq, accum_odim_ids)
            vscores = accum_best_scores
            vidx = to_device(h[0], torch.LongTensor(accum_padded_beam_ids))

            a_prev = []
            num_atts = self.num_encs if self.num_encs == 1 else self.num_encs + 1
            for idx in range(num_atts):
                if isinstance(att_w_list[idx], torch.Tensor):
                    _a_prev = torch.index_select(
                        att_w_list[idx].view(n_bb, *att_w_list[idx].shape[1:]), 0, vidx
                    )
                elif isinstance(att_w_list[idx], list):
                    # handle the case of multi-head attention
                    _a_prev = [
                        torch.index_select(att_w_one.view(n_bb, -1), 0, vidx)
                        for att_w_one in att_w_list[idx]
                    ]
                else:
                    # handle the case of location_recurrent when return is a tuple
                    _a_prev_ = torch.index_select(
                        att_w_list[idx][0].view(n_bb, -1), 0, vidx
                    )
                    _h_prev_ = torch.index_select(
                        att_w_list[idx][1][0].view(n_bb, -1), 0, vidx
                    )
                    _c_prev_ = torch.index_select(
                        att_w_list[idx][1][1].view(n_bb, -1), 0, vidx
                    )
                    _a_prev = (_a_prev_, (_h_prev_, _c_prev_))
                a_prev.append(_a_prev)

            z_prev = [
                torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
                for li in range(self.dlayers)
            ]
            c_prev = [
                torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
                for li in range(self.dlayers)
            ]

            # pick ended hyps
            if i >= minlen:
                k = 0
                penalty_i = (i + 1) * penalty
                thr = accum_best_scores[:, -1]
                for samp_i in six.moves.range(batch):
                    if stop_search[samp_i]:
                        k = k + beam
                        continue
                    for beam_j in six.moves.range(beam):
                        _vscore = None
                        if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                            yk = y_prev[k][:]
                            if len(yk) <= min(
                                hlens[idx][samp_i] for idx in range(self.num_encs)
                            ):
                                _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                        elif i == maxlen - 1:
                            yk = yseq[k][:]
                            _vscore = vscores[samp_i][beam_j] + penalty_i
                        if _vscore:
                            yk.append(self.eos)
                            if rnnlm:  # Word LM needs to add final <eos> score
                                _vscore += recog_args.lm_weight * rnnlm.final(
                                    rnnlm_state, index=k
                                )
                            _score = _vscore.data.cpu().numpy()
                            ended_hyps[samp_i].append(
                                {"yseq": yk, "vscore": _vscore, "score": _score}
                            )
                        k = k + 1

            # end detection
            stop_search = [
                stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
                for samp_i in six.moves.range(batch)
            ]
            stop_search_summary = list(set(stop_search))
            if len(stop_search_summary) == 1 and stop_search_summary[0]:
                break

            if rnnlm:
                rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx)
            if ctc_scorer[0]:
                for idx in range(self.num_encs):
                    ctc_state[idx] = ctc_scorer[idx].index_select_state(
                        ctc_state[idx], accum_best_ids
                    )

        torch.cuda.empty_cache()

        dummy_hyps = [
            {"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}
        ]
        ended_hyps = [
            ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
            for samp_i in six.moves.range(batch)
        ]
        if normalize_score:
            for samp_i in six.moves.range(batch):
                for x in ended_hyps[samp_i]:
                    x["score"] /= len(x["yseq"])

        nbest_hyps = [
            sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[
                : min(len(ended_hyps[samp_i]), recog_args.nbest)
            ]
            for samp_i in six.moves.range(batch)
        ]

        return nbest_hyps

    def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, lang_ids=None):
        """Calculate all of the attention weights.

        :param torch.Tensor hs_pad: batch of padded hidden state sequences
            (B, Tmax, D)
            [in multi-encoder case, list of torch.Tensor,
            [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
        :param torch.Tensor hlen: batch of lengths of hidden state sequences (B)
            [in multi-encoder case, list of torch.Tensor,
            [(B), (B), ..., ] ]
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, Lmax)
        :param int strm_idx:
            stream index for parallel speaker attention in multi-speaker case
        :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) multi-encoder case =>
                [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)]
            3) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        # to support multiple-encoder asr mode: in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlen = [hlen]

        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        att_idx = min(strm_idx, len(self.att) - 1)

        # hlen should be list of list of integer
        hlen = [list(map(int, hlen[idx])) for idx in range(self.num_encs)]

        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # pad the input side with <eos> and the output side with ignore_id (-1)
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)

        # get length info
        olength = ys_out_pad.size(1)

        # initialization
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        att_ws = []
        if self.num_encs == 1:
            att_w = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in six.moves.range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    hs_pad[0], hlen[0], self.dropout_dec[0](z_list[0]), att_w
                )
                att_ws.append(att_w)
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        hs_pad[idx],
                        hlen[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlen_han = [self.num_encs] * len(ys_in)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    hs_pad_han,
                    hlen_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
                att_ws.append(att_w_list.copy())
            ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)

        if self.num_encs == 1:
            # convert to numpy array with the shape (B, Lmax, Tmax)
            att_ws = att_to_numpy(att_ws, self.att[att_idx])
        else:
            _att_ws = []
            for idx, ws in enumerate(zip(*att_ws)):
                ws = att_to_numpy(ws, self.att[idx])
                _att_ws.append(ws)
            att_ws = _att_ws
        return att_ws

    @staticmethod
    def _get_last_yseq(exp_yseq):
        last = []
        for y_seq in exp_yseq:
            last.append(y_seq[-1])
        return last

    @staticmethod
    def _append_ids(yseq, ids):
        if isinstance(ids, list):
            for i, j in enumerate(ids):
                yseq[i].append(j)
        else:
            for i in range(len(yseq)):
                yseq[i].append(ids)
        return yseq

    @staticmethod
    def _index_select_list(yseq, lst):
        new_yseq = []
        for i in lst:
            new_yseq.append(yseq[i][:])
        return new_yseq

    @staticmethod
    def _index_select_lm_state(rnnlm_state, dim, vidx):
        if isinstance(rnnlm_state, dict):
            new_state = {}
            for k, v in rnnlm_state.items():
                new_state[k] = [torch.index_select(vi, dim, vidx) for vi in v]
        elif isinstance(rnnlm_state, list):
            new_state = []
            for i in vidx:
                new_state.append(rnnlm_state[int(i)][:])
        return new_state

    # scorer interface methods
    def init_state(self, x):
        # to support multiple-encoder asr mode: in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        c_list = [self.zero_state(x[0].unsqueeze(0))]
        z_list = [self.zero_state(x[0].unsqueeze(0))]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(x[0].unsqueeze(0)))
            z_list.append(self.zero_state(x[0].unsqueeze(0)))
        # TODO(karita): support strm_index for `asr_mix`
        strm_index = 0
        att_idx = min(strm_index, len(self.att) - 1)
        if self.num_encs == 1:
            a = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            a = [None] * (self.num_encs + 1)  # atts + han
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han
        return dict(
            c_prev=c_list[:],
            z_prev=z_list[:],
            a_prev=a,
            workspace=(att_idx, z_list, c_list),
        )

    def score(self, yseq, state, x):
        # to support multiple-encoder asr mode: in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        att_idx, z_list, c_list = state["workspace"]
        vy = yseq[-1].unsqueeze(0)
        ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                x[0].unsqueeze(0),
                [x[0].size(0)],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"],
            )
        else:
            att_w = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs):
                att_c_list[idx], att_w[idx] = self.att[idx](
                    x[idx].unsqueeze(0),
                    [x[idx].size(0)],
                    self.dropout_dec[0](state["z_prev"][0]),
                    state["a_prev"][idx],
                )
            h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w[self.num_encs] = self.att[self.num_encs](
                h_han,
                [self.num_encs],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"][self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
        z_list, c_list = self.rnn_forward(
            ey, z_list, c_list, state["z_prev"], state["c_prev"]
        )
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        logp = F.log_softmax(logits, dim=1).squeeze(0)
        return (
            logp,
            dict(
                c_prev=c_list[:],
                z_prev=z_list[:],
                a_prev=att_w,
                workspace=(att_idx, z_list, c_list),
            ),
        )
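
    # Scorer-interface sketch (illustrative): stepwise scoring of a prefix
    # against a single encoder output ``enc_out`` of shape (T, eprojs):
    #
    #   state = dec.init_state(enc_out)
    #   logp, state = dec.score(torch.tensor([dec.sos]), state, enc_out)
    #   # logp: (odim,) log-probabilities over the next token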


def decoder_for(args, odim, sos, eos, att, labeldist):
    return Decoder(
        args.eprojs,
        odim,
        args.dtype,
        args.dlayers,
        args.dunits,
        sos,
        eos,
        att,
        args.verbose,
        args.char_list,
        labeldist,
        args.lsm_weight,
        args.sampling_probability,
        args.dropout_rate_decoder,
        getattr(args, "context_residual", False),  # use getattr to keep compatibility
        getattr(args, "replace_sos", False),  # use getattr to keep compatibility
        getattr(args, "num_encs", 1),  # use getattr to keep compatibility
    )
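

if __name__ == "__main__":
    # Smoke-test sketch (illustrative only, not part of the library API).
    # It assumes AttDot is available from funasr.modules.rnn.attentions with
    # the ESPnet-style signature AttDot(eprojs, dunits, att_dim); swap in any
    # attention module with the same call convention if it is not.
    from funasr.modules.rnn.attentions import AttDot

    odim = 10
    args = Namespace(
        eprojs=320,
        dtype="lstm",
        dlayers=1,
        dunits=300,
        verbose=0,
        char_list=None,
        lsm_weight=0.0,
        sampling_probability=0.0,
        dropout_rate_decoder=0.0,
    )
    # the decoder indexes self.att, so the attention must be wrapped in a ModuleList
    att = torch.nn.ModuleList([AttDot(args.eprojs, args.dunits, 320)])
    dec = decoder_for(args, odim, sos=odim - 1, eos=odim - 1, att=att, labeldist=None)
    hs_pad = torch.randn(2, 30, args.eprojs)  # (B, Tmax, eprojs) dummy encoder output
    hlens = torch.tensor([30, 25])
    ys_pad = torch.tensor([[1, 2, 3, 4], [1, 2, -1, -1]])  # -1 pads (ignore_id)
    loss, acc, ppl = dec(hs_pad, hlens, ys_pad)
    print("loss:", loss.item(), "acc:", acc, "ppl:", ppl)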